Hey guys, I just published the developer version of NanoDL, a library for building transformer models in the Jax/Flax ecosystem, and I would love your feedback!
Key Features of NanoDL include:
- A wide array of blocks and layers for building customised transformer models from scratch (a sketch of this style of composition follows the list).
- An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred architecture), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers, so developers can efficiently train large-scale models on multiple GPUs or TPUs without writing manual training loops (see the pmap sketch below).
- Dataloaders that make data handling for Jax/Flax more straightforward and effective.
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and Swin attention, allowing for more flexible model development (see the RoPE sketch below).
- GPU/TPU-accelerated classical ML models like PCA, KMeans, regression, and Gaussian Processes, akin to scikit-learn on GPU (see the PCA sketch below).
- Modular design, so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian blur and BLEU (see the blur sketch below).
- Each model is contained in a single file with no external dependencies, so the source code can also be used on its own.
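To give a flavour of what composing a model from reusable blocks looks like, here is a generic pre-norm transformer block in plain Flax. This is an illustrative sketch, not NanoDL's actual API; `TransformerBlock`, `dim`, and `heads` are names I made up for the example.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TransformerBlock(nn.Module):
    """A generic pre-norm transformer block (illustrative, not NanoDL's API)."""
    dim: int
    heads: int

    @nn.compact
    def __call__(self, x):
        # Self-attention sub-layer with a residual connection.
        h = nn.LayerNorm()(x)
        h = nn.MultiHeadDotProductAttention(num_heads=self.heads)(h, h)
        x = x + h
        # Position-wise feed-forward sub-layer with a residual connection.
        h = nn.LayerNorm()(x)
        h = nn.Dense(4 * self.dim)(h)
        h = nn.gelu(h)
        h = nn.Dense(self.dim)(h)
        return x + h

# Hypothetical usage: initialise with (batch, seq, dim) inputs.
model = TransformerBlock(dim=64, heads=4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 16, 64)))
```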
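On the distributed trainers: the sketch below shows, in plain JAX, the kind of pmap-based data-parallel step such trainers typically wrap. It uses a toy linear model with hand-rolled SGD; the names and the 1e-3 learning rate are mine, not NanoDL's trainer API.

```python
from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model; a real trainer would call a Flax model's apply().
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients across devices so every replica stays in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# Replicate the parameters and shard the batch across local devices.
n = jax.local_device_count()
params = jax.device_put_replicated(
    {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}, jax.local_devices())
x = jnp.ones((n, 8, 4))  # leading axis = device axis
y = jnp.ones((n, 8, 1))
params = train_step(params, x, y)
```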
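For anyone unfamiliar with RoPE, here is a minimal rotary-position-embedding function in JAX, using the common split-halves formulation; NanoDL's implementation details may differ.

```python
import jax.numpy as jnp

def rope(x, base=10000.0):
    # x: (seq_len, num_heads, head_dim); head_dim must be even.
    seq_len, _, head_dim = x.shape
    half = head_dim // 2
    freqs = 1.0 / (base ** (jnp.arange(half) / half))        # (half,)
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]   # (seq, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    return jnp.concatenate(
        [x1 * cos[:, None, :] - x2 * sin[:, None, :],
         x1 * sin[:, None, :] + x2 * cos[:, None, :]], axis=-1)
```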
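On the classical ML side, something like PCA reduces to a few lines of jnp, which is what makes it trivially GPU/TPU-accelerated. This sketch is mine, not the library's code.

```python
import jax.numpy as jnp

def pca_transform(x, k):
    # x: (n_samples, n_features). Project onto the top-k principal components.
    x_centered = x - x.mean(axis=0)
    # Rows of vt are the principal directions (right singular vectors).
    _, _, vt = jnp.linalg.svd(x_centered, full_matrices=False)
    return x_centered @ vt[:k].T
```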
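Similarly, a Gaussian blur is just a convolution with a normalised Gaussian kernel; a minimal JAX version (again my sketch, not NanoDL's implementation):

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def gaussian_blur(img, sigma=1.0, radius=2):
    # img: (H, W) grayscale. Build a normalised 2D Gaussian kernel and convolve.
    ax = jnp.arange(-radius, radius + 1)
    xx, yy = jnp.meshgrid(ax, ax)
    kernel = jnp.exp(-(xx**2 + yy**2) / (2 * sigma**2))
    kernel = kernel / kernel.sum()
    return convolve2d(img, kernel, mode="same")
```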
Check out the repository for sample usage and more details: https://github.com/HMUNACHI/nanodl
Ultimately, I want as many opinions as possible: next steps to consider, issues, even contributions.
Note: I am still working on the README docs. For now, each model file includes a comprehensive usage example in comments at the top.